{
  "cells": [
    {
      "cell_type": "code",
      "source": [
        "# Copyright 2019 The Google Research Authors.\n",
        "#\n",
        "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "#     http://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License.\n",
        "import sys\n",
        "sys.path.append(\"..\")\n",
        "import os\n",
        "import json\n",
        "import numpy as np\n",
        "import pandas as pd\n",
        "import functools\n",
        "from dqn import molecules\n",
        "from dqn import deep_q_networks\n",
        "from dqn.py.SA_Score import sascorer\n",
        "\n",
        "from rdkit import Chem, DataStructs\n",
        "from rdkit.Chem import AllChem, Draw, Descriptors, QED\n",
        "\n",
        "import matplotlib.pyplot as plt\n",
        "import tensorflow as tf\n",
        "from pathlib import Path"
      ],
      "outputs": [],
      "execution_count": null,
      "metadata": {
        "inputHidden": false,
        "outputHidden": false
      }
    },
    {
      "cell_type": "code",
      "source": [
        "def latest_ckpt(path):\n",
        "    \"\"\"Return the highest checkpoint step number found in directory `path`.\n",
        "\n",
        "    Entries are expected to be named ``ckpt-<step>`` followed by optional\n",
        "    suffixes (for example ``ckpt-1000.index``). Entries whose name does not\n",
        "    start with that pattern, such as a file named ``checkpoint``, are ignored.\n",
        "\n",
        "    Args:\n",
        "        path: pathlib.Path of the directory to scan.\n",
        "\n",
        "    Returns:\n",
        "        int, the largest step number among the matching entries.\n",
        "\n",
        "    Raises:\n",
        "        ValueError: if the directory contains no ``ckpt-<step>`` entries.\n",
        "    \"\"\"\n",
        "    import re  # Local import keeps this cell independent of the import cell.\n",
        "\n",
        "    # Match on the full file name instead of splitting the stem on '-', so\n",
        "    # names with extra dots or dashes after the step number still parse.\n",
        "    step_pattern = re.compile(r'ckpt-(\\d+)')\n",
        "    steps = []\n",
        "    for entry in path.iterdir():\n",
        "        match = step_pattern.match(entry.name)\n",
        "        if match:\n",
        "            steps.append(int(match.group(1)))\n",
        "    if not steps:\n",
        "        raise ValueError('No ckpt-<step> files found in %s' % path)\n",
        "    return max(steps)"
      ],
      "outputs": [],
      "execution_count": null,
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "# Root directory holding the run folders named mol<index>_target_<sas>.\n",
        "# Set the MOLDQN_TARGET_SAS_DIR environment variable to use another copy of\n",
        "# the results; the default is the location this notebook was written against.\n",
        "results_root = os.environ.get(\n",
        "    'MOLDQN_TARGET_SAS_DIR', '/Users/odin/sherlock_scratch/moldqn2/target_sas')\n",
        "# basepath stays a %-format template, so escape any literal '%' in the root.\n",
        "basepath = os.path.join(results_root.replace('%', '%%'), 'mol%i_target_%.1f')\n",
        "path = Path(basepath % (1, 4.8))\n",
        "# Sanity check: show the newest checkpoint step of one example run.\n",
        "latest_ckpt(path)"
      ],
      "outputs": [],
      "execution_count": null,
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "# Starting molecules as SMILES strings. The position in this list is the\n",
        "# molecule index passed to eval() and substituted into basepath to locate the\n",
        "# run directory for that molecule.\n",
        "all_molecules = [\"CCCN(C)N=Nc1ccc(cc1)C(=O)O\",\n",
        " \"CN1CCC[C@H]1c2cccnc2\",\n",
        " \"CCCCCC(O)c1cccc(OCc2cccc(c2)C(=O)OC)c1\",\n",
        " \"CCc1c(C)[nH]c2CCC(CN3CCOCC3)C(=O)c12\",\n",
        " \"COc1cc(cc(OC)c1OC)C(=O)N2CCN(C(COC(=O)CC(C)(C)C)C2)C(=O)c3cc(OC)c(OC)c(OC)c3\",\n",
        " \"Cc1cc(C)cc(c1)N2C(=O)Cc3ccccc3C2=O\",\n",
        " \"CCCCCCCCCCCCc1ccc(OCCCC(C)(C)C(=O)O)cc1OCCCC(C)(C)C(=O)O\",\n",
        " \"COc1ccc(C[C@@H](C)NC[C@H](O)c2ccc(O)c(NC=O)c2)cc1\",\n",
        " \"CC12CC3CC(C)(C1)CC(N)(C3)C2\",\n",
        " \"CC(C)NCC(O)COC(=O)c1ccc(NC(=O)C)cc1\",\n",
        " \"CCC1=C(CNC1=O)c2ccc(cc2)n3ccnc3\",\n",
        " \"CN(C)CCCn1cc(C2=C(C(=O)NC2=O)c3cn(CCOCCO)c4ccccc34)c5ccccc15\",\n",
        " \"Cc1c(ccc2nc(N)nc(N)c12)C(=O)NC(CCC(=O)O)C(=O)O\",\n",
        " \"COc1cc2nc(nc(N)c2cc1OC)N3CCN(CC3)C(=O)Nc4ccccc4\",\n",
        " \"O=C1CCC(N2C(=O)c3ccccc3C2=O)C(=O)N1\",\n",
        " \"OC(=O)c1ccccc1\",\n",
        " \"CCCCCCCC1=CC(=CC(=O)O1)OC\",\n",
        " \"CCOc1ccc2c(c1)c(CCNC(=O)C3CC3)c4c5ccccc5CCCn24\",\n",
        " \"CC(Cc1ccc(O)c(O)c1)C(C)Cc2ccc(O)c(O)c2\",\n",
        " \"CCN1c2ccccc2Cc3c(O)ncnc13\",\n",
        " \"CN1C(=O)C2(OCCO2)c3ccccc13\",\n",
        " \"CN1C(=O)NC2=C(N(C)C(=O)N2)C1=O\",\n",
        " \"CN(C)c1ncnc2c1ncn2Cc3cccc(C)c3\",\n",
        " \"Cc1cccc(CC2CCc3nc(N)nc(N)c3C2)c1\",\n",
        " \"CC(C)NCC(O)COC(=O)c1ccc(CO)cc1\",\n",
        " \"OC(=O)CN(CCN(CC(=O)O)CC(=O)O)CC(=O)O\",\n",
        " \"CN(CC=C)CC(N(C)CC=C)C(=O)Nc1c(C)cccc1C\",\n",
        " \"OCCCc1cc2OCCc2cc1O\",\n",
        " \"Cc1cc(CCCCCCCOc2ccc(cc2)C3=NC(C)(C)CO3)on1\",\n",
        " \"CCCCCCN1CCN2CC(c3ccccc3)c4ccccc4C2C1\",\n",
        " \"COc1c2OC(=O)C=Cc2c(COCCCO)c3ccoc13\"]"
      ],
      "outputs": [],
      "execution_count": null,
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "def eval(model_dir, idx):\n",
        "  \"\"\"Run one greedy episode with the latest checkpoint stored in `model_dir`.\n",
        "\n",
        "  Hyperparameters are read from `<model_dir>/config.json`; if that file does\n",
        "  not exist, the config.json in the shared target_sas directory is used\n",
        "  instead. The restored Q-network then picks an action with epsilon=0.0 at\n",
        "  every step, starting from `all_molecules[idx]`.\n",
        "\n",
        "  NOTE(review): this function shadows the builtin `eval`; the name is kept\n",
        "  because other cells call it.\n",
        "\n",
        "  Args:\n",
        "    model_dir: str, directory containing `ckpt-<step>` checkpoint files and\n",
        "      optionally a `config.json`.\n",
        "    idx: int, index into `all_molecules` selecting the starting molecule.\n",
        "\n",
        "  Returns:\n",
        "    A tuple `(ckpt, result)`: the checkpoint step number that was restored\n",
        "    and the value returned by the last `environment.step` call.\n",
        "\n",
        "  Raises:\n",
        "    ValueError: if `hparams.max_steps_per_episode` allows no step at all, so\n",
        "      there is no result to return.\n",
        "  \"\"\"\n",
        "  ckpt = latest_ckpt(Path(model_dir))\n",
        "  hparams_file = os.path.join(model_dir, 'config.json')\n",
        "  try:\n",
        "    fh = open(hparams_file, 'r')\n",
        "  except FileNotFoundError:\n",
        "    # Fall back to the shared config when the run has no config of its own.\n",
        "    fh = open('/Users/odin/sherlock_scratch/moldqn2/target_sas/config.json', 'r')\n",
        "  # The context manager closes the file even if the JSON is malformed.\n",
        "  with fh:\n",
        "    hp_dict = json.load(fh)\n",
        "  hparams = deep_q_networks.get_hparams(**hp_dict)\n",
        "\n",
        "  environment = molecules.Molecule(\n",
        "      atom_types=set(hparams.atom_types),\n",
        "      init_mol=all_molecules[idx],\n",
        "      allow_removal=hparams.allow_removal,\n",
        "      allow_no_modification=hparams.allow_no_modification,\n",
        "      allowed_ring_sizes=set(hparams.allowed_ring_sizes),\n",
        "      allow_bonds_between_rings=hparams.allow_bonds_between_rings,\n",
        "      max_steps=hparams.max_steps_per_episode)\n",
        "\n",
        "  dqn = deep_q_networks.DeepQNetwork(\n",
        "      input_shape=(hparams.batch_size, hparams.fingerprint_length + 1),\n",
        "      q_fn=functools.partial(\n",
        "          deep_q_networks.multi_layer_model, hparams=hparams),\n",
        "      optimizer=hparams.optimizer,\n",
        "      grad_clipping=hparams.grad_clipping,\n",
        "      num_bootstrap_heads=hparams.num_bootstrap_heads,\n",
        "      gamma=hparams.gamma,\n",
        "      epsilon=0.0)\n",
        "\n",
        "  # Start from an empty graph so repeated calls do not accumulate variables\n",
        "  # from earlier models in the default graph.\n",
        "  tf.reset_default_graph()\n",
        "  result = None\n",
        "  with tf.Session() as sess:\n",
        "    dqn.build()\n",
        "    model_saver = tf.train.Saver(max_to_keep=hparams.max_num_checkpoints)\n",
        "    model_saver.restore(sess, os.path.join(model_dir, 'ckpt-%i' % ckpt))\n",
        "    environment.initialize()\n",
        "    for _ in range(hparams.max_steps_per_episode):\n",
        "      steps_left = hparams.max_steps_per_episode - environment.num_steps_taken\n",
        "\n",
        "      if hparams.num_bootstrap_heads:\n",
        "        head = np.random.randint(hparams.num_bootstrap_heads)\n",
        "      else:\n",
        "        head = 0\n",
        "      valid_actions = list(environment.get_valid_actions())\n",
        "      # One row per candidate action: its fingerprint with steps_left appended.\n",
        "      observations = np.vstack(\n",
        "        [np.append(deep_q_networks.get_fingerprint(act, hparams), steps_left)\n",
        "         for act in valid_actions])\n",
        "      action = valid_actions[dqn.get_action(\n",
        "          observations, head=head, update_epsilon=0.0)]\n",
        "      result = environment.step(action)\n",
        "  if result is None:\n",
        "    raise ValueError(\n",
        "        'No step was taken: max_steps_per_episode=%r' %\n",
        "        (hparams.max_steps_per_episode,))\n",
        "  return ckpt, result\n"
      ],
      "outputs": [],
      "execution_count": null,
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "# Evaluate every starting molecule against both SAS targets and collect one\n",
        "# row per (molecule, target) run.\n",
        "all_results = []\n",
        "for i, smiles in enumerate(all_molecules):\n",
        "    # The starting molecule's score does not depend on the target: compute once.\n",
        "    ori_sas = sascorer.calculateScore(Chem.MolFromSmiles(smiles))\n",
        "    for target in (2.5, 4.8):\n",
        "        ckpt, result = eval(basepath % (i, target), i)\n",
        "        sas = sascorer.calculateScore(Chem.MolFromSmiles(result.state))\n",
        "        all_results.append((i, ckpt, smiles, result.state, ori_sas, target, sas))"
      ],
      "outputs": [],
      "execution_count": null,
      "metadata": {
        "scrolled": true
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Column order mirrors the tuples appended to all_results.\n",
        "result_columns = [\n",
        "    'index', 'ckpt', 'original_molecule', 'generated_molecule',\n",
        "    'original_sas', 'target_sas', 'sas',\n",
        "]\n",
        "df = pd.DataFrame(all_results, columns=result_columns)\n",
        "# Keep a copy of the table on disk next to the notebook.\n",
        "df.to_csv('target_sas_results.csv')"
      ],
      "outputs": [],
      "execution_count": null,
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "# Scatter the SAS of each starting molecule against the SAS of the molecule\n",
        "# generated from it, with one series per target value.\n",
        "fig, ax = plt.subplots()\n",
        "for target in (2.5, 4.8):\n",
        "    subset = df[df['target_sas'] == target]\n",
        "    ax.scatter(subset['original_sas'], subset['sas'],\n",
        "               label='target_sas=%.1f' % target)\n",
        "\n",
        "ax.set_xlabel('SAS of original molecule')\n",
        "ax.set_ylabel('SAS of generated molecule')\n",
        "ax.legend()\n",
        "plt.show()"
      ],
      "outputs": [],
      "execution_count": null,
      "metadata": {}
    }
  ],
  "metadata": {
    "kernel_info": {
      "name": "python3"
    },
    "kernelspec": {
      "name": "python3",
      "language": "python",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python",
      "version": "3.6.7",
      "mimetype": "text/x-python",
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "pygments_lexer": "ipython3",
      "nbconvert_exporter": "python",
      "file_extension": ".py"
    },
    "nteract": {
      "version": "0.14.0"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 4
}